import jieba
from tensorflow import keras
from nlu_model.sim.model.dssm_base import DssmBase
from nlu_model.sim.model.cdssm_base import CdssmBase
from nlu_model.sim.model.lstm_dssm_base import LstmDssmBase

class SimModel(object):
    """Thin facade over the sentence-similarity models.

    Depending on ``model_choice`` the wrapped ``self.model`` is a
    ``DssmBase``, ``CdssmBase`` or ``LstmDssmBase`` built from
    ``model_conf`` / ``train_conf``, or a Keras model restored from
    ``model_conf["path"]`` when ``model_choice == "load"``.

    ``model_conf["emb_model"]`` must provide ``batch2idx`` for the text
    helpers (``preprocess``, ``predict``, ``pred_vec``) to work.
    """

    # Padded/truncated sequence length used when ``preprocess`` is not
    # given an explicit ``maxlen``.
    DEFAULT_MAX_LEN = 50

    def __init__(self, model_choice, model_conf=None, train_conf=None):
        """Store the configuration and build (or load) the wrapped model.

        Args:
            model_choice: One of ``"dssm_base"``, ``"cdssm_base"``,
                ``"lstm_dssm_base"`` or ``"load"``.
            model_conf: Model configuration dict; defaults to a fresh
                empty dict.
            train_conf: Training configuration dict; defaults to a fresh
                empty dict.

        Raises:
            ValueError: If ``model_choice`` is not a supported value.
            KeyError: If ``model_choice == "load"`` and ``model_conf`` has
                no ``"path"`` entry.
        """
        self.model_choice = model_choice
        # Fresh dicts per instance: a literal ``{}`` default would be one
        # object shared by every instance created without arguments.
        self.model_conf = {} if model_conf is None else model_conf
        self.train_conf = {} if train_conf is None else train_conf
        self.__model_select__()

    def __model_select__(self):
        """Instantiate ``self.model`` according to ``self.model_choice``.

        Raises:
            ValueError: If ``self.model_choice`` is not recognised, so a
                typo fails here instead of as a missing-attribute error
                on the first call that touches ``self.model``.
        """
        choice = self.model_choice
        if choice == "dssm_base":
            self.model = DssmBase(self.model_conf, self.train_conf)
        elif choice == "cdssm_base":
            self.model = CdssmBase(self.model_conf, self.train_conf)
        elif choice == "lstm_dssm_base":
            self.model = LstmDssmBase(self.model_conf, self.train_conf)
        elif choice == "load":
            self.load(self.model_conf["path"])
        else:
            raise ValueError(
                "unsupported model_choice {!r}; expected one of 'dssm_base', "
                "'cdssm_base', 'lstm_dssm_base', 'load'".format(choice)
            )

    def preprocess(self, sentences, maxlen=None):
        """Tokenise sentences and convert them to a padded id matrix.

        Args:
            sentences: Iterable of sentence strings. Each element is
                segmented with ``jieba.cut``.
                NOTE(review): a bare string would be iterated character
                by character -- callers should pass a list.
            maxlen: Target sequence length; ``None`` (the default) uses
                ``DEFAULT_MAX_LEN``.

        Returns:
            The result of ``keras.preprocessing.sequence.pad_sequences``:
            sequences padded at the end with 0 up to ``maxlen``. Longer
            sequences are truncated according to Keras' default
            ``truncating`` setting.
        """
        if maxlen is None:
            maxlen = self.DEFAULT_MAX_LEN
        tokenized = [list(jieba.cut(sentence)) for sentence in sentences]
        # Presumably maps each token to its vocabulary index -- the
        # embedding model is supplied by the caller via model_conf.
        token_ids = self.model_conf["emb_model"].batch2idx(tokenized)
        return keras.preprocessing.sequence.pad_sequences(token_ids,
                                                          value=0,
                                                          padding='post',
                                                          maxlen=maxlen)

    def fit(self, x_train, y_train, x_test, y_test):
        """Delegate training to the wrapped model and return its result.

        NOTE(review): after ``load()`` the wrapped object is a plain Keras
        model whose ``fit`` likely has a different positional signature --
        confirm before training a loaded model through this method.
        """
        return self.model.fit(x_train, y_train, x_test, y_test)

    def evaluate(self, x_test, y_test):
        """Delegate evaluation to the wrapped model and return its result."""
        return self.model.evaluate(x_test, y_test)

    def pred(self, sentence):
        """Run the wrapped model's ``predict`` on already-encoded input."""
        return self.model.predict(sentence)

    def predict(self, sentences):
        """Score the first two raw sentences against each other.

        Args:
            sentences: Sequence of raw sentence strings; only the first
                two entries are used.

        Returns:
            Element ``[0][0]`` of the model's prediction for that pair
            (presumably the similarity score).
        """
        sentence_id = self.preprocess(sentences)
        # Each side is wrapped in its own list to form a batch of size one.
        left, right = [sentence_id[0]], [sentence_id[1]]
        return self.pred([left, right])[0][0]

    def pred_vec(self, sentence):
        """Encode raw sentences and return the model's ``predict_vec`` output.

        Args:
            sentence: Passed straight to ``preprocess``, so presumably a
                list of sentence strings despite the singular name.

        NOTE(review): relies on the wrapped model exposing ``predict_vec``;
        a model restored through ``load()`` may not -- confirm.
        """
        sentence_id = self.preprocess(sentence)
        return self.model.predict_vec([sentence_id])

    def save(self, path):
        """Persist the wrapped model to ``path`` via its own ``save``."""
        self.model.save(path)

    def load(self, path):
        """Replace ``self.model`` with the Keras model stored at ``path``."""
        self.model = keras.models.load_model(path)


        